/*
   @Copyright:LintCode
   @Author:   tjyemail
   @Problem:  http://www.lintcode.com/problem/binary-tree-pruning
   @Language: C++
   @Datetime: 19-05-08 17:29
   */

/**
 * Definition of TreeNode:
 * class TreeNode {
 * public:
 *     int val;
 *     TreeNode *left, *right;
 *     TreeNode(int val) {
 *         this->val = val;
 *         this->left = this->right = NULL;
 *     }
 * };
 */

class Solution {
	/**
	 * Recursively prunes the subtree rooted at `root`, children first (post-order).
	 * Any child whose pruned subtree reports `true` is detached by setting the
	 * parent's link to nullptr. Detached nodes are NOT deleted: node ownership is
	 * not visible here, so freeing them could break the caller.
	 * @param root: subtree to prune; may be nullptr
	 * @return: true if the subtree is empty, or if after pruning it is a single
	 *          node with value 0 and no children. In that case the caller should
	 *          detach it.
	 */
	bool postorder(TreeNode *root){
		if(root==nullptr) return true;
		// Prune the children first so the leaf test below sees the pruned shape.
		if(postorder(root->left)) root->left=nullptr;
		if(postorder(root->right)) root->right=nullptr;
		return root->val==0 && root->left==nullptr && root->right==nullptr;
	}
public:
	/**
	 * @param root: the root; may be nullptr (an empty tree)
	 * @return: the same tree where every subtree (of the given tree) not containing a 1 has been removed,
	 *          or nullptr if nothing remains
	 */
	TreeNode * pruneTree(TreeNode * root) {
		// postorder() already reports whether the whole tree should be removed.
		// Reusing its result avoids dereferencing a null root, which the previous
		// explicit root->val check did for an empty tree.
		return postorder(root) ? nullptr : root;
	}
};
